"""
Import Needed Packages
"""
import numpy as np
import pandas as pd
from RFS_Functions.Lininterp import *
from csv import writer


"""
Structural Estimation
"""


def append_list_as_row(file_name, list_of_elem):
    """
    Append a single row of grid-search results to a CSV file (the file is created if missing)
    :param file_name: path of the outcome CSV file
    :param list_of_elem: values written, in order, as one CSV row
    :return: None; the row is written to disk as a side effect
    """
    # newline='' hands line-ending control to the csv module, as the csv docs require
    with open(file_name, 'a+', newline='') as handle:
        writer(handle).writerow(list_of_elem)


class Structural(object):
    """
    Structural estimation of the AI / old-tech / data-management production model from
    employer-by-month employment data.

    Every vector passed to the constructor is a named pandas Series; the methods combine them with
    index-to-index merges, so all of them must share the same index.
    NOTE(review): several helpers (e.g. ``productivity``) return frames re-indexed 0..n-1 (they merge on
    a key column) which are later merged back on the index, so the shared index is presumably a default
    RangeIndex -- confirm against the caller.
    """

    # Firms whose first observation falls in this year get an initial data stock proportional to their
    # first data-management headcount.
    _START_YEAR = 2015
    # Initial data stock assigned to firms whose first observation is not in _START_YEAR.
    _D0_FALLBACK = 0.0001

    def __init__(self, r, w_ai, l_ai, w_ot, l_ot, w_dm, l_dm, yearmonth, emp,
                 FE='both', FE_tol=1e-10, AI_emp='', OT_emp=''):
        """
        Create a structural object to estimate structural parameters of interest from employment
        information.
        :param r: monthly discount rate. It enters as ``r - 1 + delta`` and as the divisor in ``V / r``, so it is
                  presumably a gross rate (e.g. 1.01) -- confirm
        :param w_ai: average salary for all AI workers in a given yearmonth (Series named 'w_ai')
        :param l_ai: number of AI workers varying by emp and yearmonth (Series named 'l_ai')
        :param w_ot: average salary for all old-tech workers in a given yearmonth
        :param l_ot: number of old-tech workers varying by emp and yearmonth
        :param w_dm: average salary for all data-management workers in a given yearmonth (Series named 'w_dm')
        :param l_dm: number of data-management workers varying by emp and yearmonth (Series named 'l_dm')
        :param yearmonth: datetime Series named 'yearmonth', date of observations as of month end
        :param emp: Series named 'emp', unique employer identifier
        :param FE: (str) specification of the AI / old-tech productivity terms:
                   - 'A': only time-specific productivity (aggregated per month)
                   - 'a': only firm-specific productivity (aggregated per employer)
                   - 'both': time productivity and firm fixed effects estimated jointly
                   - 'none': productivity fixed to 1
        :param FE_tol: convergence tolerance (sum of squared changes) of the joint FE iteration
        :param AI_emp: N used to initialise the AI fixed effects at 1/N; '' means len(emp)
        :param OT_emp: N used to initialise the old-tech fixed effects at 1/N; '' means len(emp)
        """
        self.r = r
        self.w_ai = w_ai
        self.l_ai = l_ai
        self.w_ot = w_ot
        self.l_ot = l_ot
        self.w_dm = w_dm
        self.l_dm = l_dm
        self.yearmonth = yearmonth
        self.emp = emp
        self.FE = FE
        self.FE_tol = FE_tol
        self.AI_emp = AI_emp
        self.OT_emp = OT_emp
        if self.AI_emp == '':
            self.AI_emp = len(self.emp)
        if self.OT_emp == '':
            self.OT_emp = len(self.emp)

    def dit_inner(self, alpha, gamma, phi, delta, name):
        """
        Calculate the inner part of D_it implied by the data-management FOC
        :param alpha: production function exponent on data for the AI type
        :param gamma: production function exponent on data for the old-tech type
        :param phi: production function exponent on data for the data-management type
        :param delta: depreciation rate
        :param name: suffix of the returned Series name ('sub_' + name)
        :return: Series ((alpha/(1-alpha)) w_ai l_ai + (gamma/(1-gamma)) w_ot l_ot) (1-phi) / (r - 1 + delta)
        """
        sub_d_it_var = (((((alpha / (1 - alpha)) * self.w_ai * self.l_ai) +
                          ((gamma / (1 - gamma)) * self.w_ot * self.l_ot)) * (1 - phi)) /
                        (self.r - 1 + delta))
        return sub_d_it_var.rename('sub_' + name)

    def d_it_rec_calc(self, df, phi, delta, d_0, job_name):
        """
        Add the two components of the recursive data stock of one employer to ``df``:
            rec_it[t]  = sum_{s <= t} (1 - delta)^(t - s) * l[s]^(1 - phi)
            zero_it[t] = (1 - delta)^(t + 1) * d_0
        Rows are treated as consecutive periods in their current order.
        NOTE(review): this assumes the rows are sorted chronologically with no gaps -- confirm.
        :param df: rows of a single employer (modified in place and returned)
        :param phi: production function exponent on data for the data-management type
        :param delta: depreciation rate
        :param d_0: initial data stock
        :param job_name: column holding the labour input, e.g. 'l_dm'
        :return: ``df`` with new columns 'rec_it' and 'zero_it'
        """
        t_max = len(df)
        periods = np.arange(1, t_max + 1)
        # lags[t, s] = t - s. Inflows at s <= t are weighted by (1 - delta)^(t - s); future inflows get 0.
        # Clipping the exponent keeps (1 - delta) from being raised to a negative power above the diagonal.
        lags = np.subtract.outer(np.arange(t_max), np.arange(t_max))
        op = np.where(lags >= 0, (1 - delta) ** np.maximum(lags, 0), 0.0)
        df['rec_it'] = np.matmul(op, df[job_name].to_numpy() ** (1 - phi))
        df['zero_it'] = (((1 - delta) ** periods) * d_0)
        return df

    def _initial_stock(self, grp, phi, ups):
        """
        Initial data stock of one employer.
        If the employer's first row is in _START_YEAR, return ups * l_dm[first]^(1 - phi);
        otherwise return _D0_FALLBACK.
        """
        if grp['year'].to_numpy()[0] == self._START_YEAR:
            return ups * (grp.l_dm.to_numpy()[0] ** (1 - phi))
        return self._D0_FALLBACK

    def rec_d_it_emp(self, grp, phi, delta, ups):
        """
        Compute rec_it and zero_it for one employer given its initial data stock
        :param grp: employer group (needs columns 'year' and 'l_dm')
        :param phi: production function exponent on data for the data-management type
        :param delta: depreciation rate
        :param ups: scaling constant chosen so the average initial data stock equals the target average
        :return: ``grp`` with columns 'rec_it' and 'zero_it'
        """
        return self.d_it_rec_calc(grp, phi, delta, self._initial_stock(grp, phi, ups), 'l_dm')

    def d_i0_calc(self, grp, phi, ups):
        """
        Attach the initial data stock of one employer
        :param grp: employer group (needs columns 'year' and 'l_dm')
        :param phi: production function exponent on data for the data-management type
        :param ups: scaling constant chosen so the average initial data stock equals the target average
        :return: ``grp`` with a constant column 'd_i0'
        """
        grp['d_i0'] = self._initial_stock(grp, phi, ups)
        return grp

    def dit_emp(self, d_0_av, phi, delta, zero=False):
        """
        Calculate the recursive data stock (or its initial value) for every firm i and period t
        :param d_0_av: target average initial data stock at firm level
        :param phi: production function exponent on data for the data-management type
        :param delta: depreciation rate
        :param zero: if True, return the initial stocks instead of the recursion components
        :return: (zero_it, rec_it) Series, or (ups, d_i0) Series when ``zero`` is True; both are indexed
                 like the inputs
        """
        df = pd.merge(self.yearmonth, self.emp, left_index=True, right_index=True)
        df = df.merge(self.l_dm, left_index=True, right_index=True)
        # Remember the original row labels so the groupby-apply output can be put back in input order
        df['ii'] = df.index
        df['year'] = df['yearmonth'].dt.year
        first_dm = df[df.year == self._START_YEAR].groupby('emp')['l_dm'].first()
        ups = d_0_av / (first_dm ** (1 - phi)).mean()
        if zero:
            out = df.groupby('emp').apply(lambda grp: self.d_i0_calc(grp, phi, ups))
            out = out.sort_values('ii').set_index('ii')
            out['ups'] = ups
            return out.ups, out.d_i0
        out = df.groupby('emp').apply(lambda grp: self.rec_d_it_emp(grp, phi, delta, ups))
        out = out.sort_values('ii').set_index('ii')
        return out.zero_it, out.rec_it

    def _recursive_data_stock(self, d_0_av, phi, delta):
        """Return D_it_rec = zero_it + rec_it as a Series named 'd_it_rec'."""
        zero_it, rec_it = self.dit_emp(d_0_av, phi, delta)
        return (zero_it + rec_it).rename('d_it_rec')

    def _level_grouping(self, df, level):
        """
        Sum columns 'a' and 'b' of ``df`` by employer or by month
        :param df: DataFrame containing ['yearmonth', 'emp', 'a', 'b']
        :param level: (str) 'i' -> sums per employer ('emp'); 't' -> sums per month ('yearmonth')
        :return: Series a, b indexed by the grouping key; zero sums in ``a`` are replaced by 1e-9
        :raises ValueError: if a required column is missing or ``level`` is not 'i' or 't'
        """
        if not all(col in df.columns for col in ['yearmonth', 'emp', 'a', 'b']):
            raise ValueError("The input DataFrame for the function <level_grouping> must contain the following columns:"
                             "['yearmonth', 'emp', 'a', 'b']")
        if level == 'i':
            key = 'emp'
        elif level == 't':
            key = 'yearmonth'
        else:
            raise ValueError('The variable <level> must be set to "i" or "t"')
        grouped = df.groupby(key)
        a = grouped['a'].sum()
        # ``a`` is later used as a denominator (b / a); avoid dividing by exactly zero
        a.loc[a == 0] = 0.000000001
        b = grouped['b'].sum()
        return a, b

    def data_process_calc(self, d_it_rec, sub_d_it, phi):
        """
        Compute residuals from the data-process condition
        :param d_it_rec: recursive data stock (Series named 'd_it_rec')
        :param sub_d_it: inner part of D_it from the data-management FOC (Series named 'sub_d_it')
        :param phi: production function exponent on data for the data-management type
        :return: Series D_it - D_it_rec, with D_it = sub_d_it * l_dm^(-phi) / w_dm
        """
        df = pd.merge(self.yearmonth, self.emp, left_index=True, right_index=True)
        for part in (d_it_rec, sub_d_it, self.l_dm, self.w_dm):
            df = df.merge(part, left_index=True, right_index=True)
        df['d_it'] = df['sub_d_it'] * (df['l_dm']**(- phi)) * (df['w_dm']**(-1))
        return df['d_it'] - df['d_it_rec']

    def productivity(self, d, l, w, param, name, level='t', f=''):
        """
        Retrieve a productivity term of the Cobb-Douglas production function.
        The term is sum(w * l) / sum((1 - param) [* f] * d^param * l^(1 - param)), where the sums run over
        rows with positive labour, grouped by month or by employer.
        :param d: level of accumulated data varying by emp and yearmonth
        :param l: number of workers of a given type varying by emp and yearmonth
        :param w: average salary for workers of that type
        :param param: data exponent of the Cobb-Douglas function
        :param name: name of the returned productivity column
        :param level: (str) 't' -> one value per yearmonth; 'i' -> one value per employer
        :param f: other multiplicative factor (firm fixed effect or time productivity); required and used
                  only when FE == 'both'
        :return: DataFrame [[name]] aligned to the rows of the inputs (re-indexed 0..n-1). When level == 'i'
                 and FE == 'both', also returns a DataFrame [['emp', name]].
        :raises ValueError: if FE == 'both' and ``f`` was not provided, or if ``level`` is not 'i' or 't'
        """
        df = pd.merge(self.yearmonth, self.emp, left_index=True, right_index=True)
        df = df.merge(l, left_index=True, right_index=True)
        df = df.merge(d, left_index=True, right_index=True)
        df = df.merge(w, left_index=True, right_index=True)
        if self.FE == 'both':
            if isinstance(f, str):
                raise ValueError('ERROR in productivity calculation. '
                                 'The fixed effects vector must be defined when FE == both')
            df = df.merge(f, left_index=True, right_index=True)
        # Columns by position: 0 yearmonth, 1 emp, 2 labour, 3 data, 4 wage, 5 factor f (FE == 'both')
        l_col, d_col, w_col = df.columns[2:5]
        # Explicit copy: df1 is a filtered slice whose columns are assigned below
        df1 = df[df[l_col] > 0].copy()
        if self.FE == 'both':
            df1['a'] = ((1 - param) * df1[df1.columns[5]] * (df1[d_col] ** param) *
                        (df1[l_col] ** (1 - param)))
        else:
            df1['a'] = ((1 - param) * (df1[d_col] ** param) * (df1[l_col] ** (1 - param)))
        df1['b'] = (df1[w_col] * df1[l_col])
        a, b = self._level_grouping(df1, level)
        key = 'yearmonth' if level == 't' else 'emp'
        ratio = pd.merge(a.reset_index(), b.reset_index(), on=key)
        ratio[name] = ratio.b / ratio.a
        df = df.merge(ratio[[key, name]], on=key, how='left')
        if level == 'i' and self.FE == 'both':
            return df[[name]], df[['emp', name]]
        return df[[name]]

    def cobb_douglas(self, a, d, l, w, param, name):
        """
        Compute the Cobb-Douglas output term a * d^param * l^(1 - param)
        :param a: productivity
        :param d: data stock
        :param l: labour
        :param w: wage (merged for row alignment only; not used in the formula)
        :param param: data exponent
        :param name: name of the returned column
        :return: DataFrame [[name]]
        """
        df = pd.merge(l, d, left_index=True, right_index=True)
        df = df.merge(w, left_index=True, right_index=True)
        df = df.merge(a, left_index=True, right_index=True)
        df[name] = df[df.columns[3]] * (df[df.columns[1]] ** param) * (df[df.columns[0]] ** (1 - param))
        return df[[name]]

    def cobb_douglas_fe(self, a, d, l, w, f, param, name):
        """
        Compute the Cobb-Douglas output term a * f * d^param * l^(1 - param)
        :param a: time productivity
        :param d: data stock
        :param l: labour
        :param w: wage (merged for row alignment only; not used in the formula)
        :param f: firm fixed effect
        :param param: data exponent
        :param name: name of the returned column
        :return: DataFrame [[name]]
        """
        df = pd.merge(l, d, left_index=True, right_index=True)
        df = df.merge(w, left_index=True, right_index=True)
        df = df.merge(a, left_index=True, right_index=True)
        df = df.merge(f, left_index=True, right_index=True)
        df[name] = df[df.columns[3]] * df[df.columns[4]] * \
                   (df[df.columns[1]] ** param) * (df[df.columns[0]] ** (1 - param))
        return df[[name]]

    def foc_cobb_douglas(self, a, d, l, w, param):
        """
        Compute labour-FOC residuals (1 - param) a d^param l^(1 - param) - w l on rows with l > 0
        :param a: productivity
        :param d: data stock
        :param l: labour
        :param w: wage
        :param param: data exponent
        :return: Series of residuals
        """
        df = pd.merge(l, d, left_index=True, right_index=True)
        df = df.merge(w, left_index=True, right_index=True)
        df = df.merge(a, left_index=True, right_index=True)
        df = df[df[df.columns[0]] > 0]
        return (1 - param) * df[df.columns[3]] * (df[df.columns[1]] ** param) * (df[df.columns[0]] ** (1 - param)) - \
            (df[df.columns[2]] * df[df.columns[0]])

    def foc_cobb_douglas_FE(self, a, d, l, w, f, param):
        """
        Compute labour-FOC residuals (1 - param) a f d^param l^(1 - param) - w l on rows with l > 0
        :param a: time productivity
        :param d: data stock
        :param l: labour
        :param w: wage
        :param f: firm fixed effect
        :param param: data exponent
        :return: Series of residuals
        """
        df = pd.merge(l, d, left_index=True, right_index=True)
        df = df.merge(w, left_index=True, right_index=True)
        df = df.merge(a, left_index=True, right_index=True)
        df = df.merge(f, left_index=True, right_index=True)
        df = df[df[df.columns[0]] > 0]
        return (1 - param) * df[df.columns[3]] * df[df.columns[4]] * (df[df.columns[1]] ** param) * \
            (df[df.columns[0]] ** (1 - param)) - (df[df.columns[2]] * df[df.columns[0]])

    @staticmethod
    def _sq_change(new, old):
        """
        Sum of squared element-wise differences, aligned by position; NaNs are skipped as in pandas sums.
        """
        new_vals = pd.Series(np.asarray(new, dtype=float).ravel())
        old_vals = pd.Series(np.asarray(old, dtype=float).ravel())
        return float(((new_vals - old_vals) ** 2).sum())

    def _joint_fe_productivity(self, d_it_rec, alpha, gamma):
        """
        Alternate between time productivity (level 't') and firm fixed effects (level 'i') for the AI and
        old-tech types until the summed squared change of all four vectors is <= FE_tol.
        Used when FE == 'both'; returns unnormalised values.
        :return: ((a_ai, fe_ai, df_fe_ai), (a_ot, fe_ot, df_fe_ot), n_iterations, final_difference)
        """
        fe_ai = (self.emp * 0 + 1 / self.AI_emp).rename('fe_ai')
        fe_ot = (self.emp * 0 + 1 / self.OT_emp).rename('fe_ot')
        a_ai = self.productivity(d_it_rec, self.l_ai, self.w_ai, alpha, 'a_ai', level='t', f=fe_ai)
        a_ot = self.productivity(d_it_rec, self.l_ot, self.w_ot, gamma, 'a_ot', level='t', f=fe_ot)
        # Offset the "previous" values so the first convergence check can never pass
        fe_ai_old, fe_ot_old = fe_ai + 1000, fe_ot + 1000
        a_ai_old, a_ot_old = a_ai + 1000, a_ot + 1000
        df_fe_ai = df_fe_ot = None
        diff = np.inf
        n_iter = 0
        while diff > self.FE_tol:
            fe_ai, df_fe_ai = self.productivity(d_it_rec, self.l_ai, self.w_ai, alpha, 'fe_ai', level='i', f=a_ai)
            a_ai = self.productivity(d_it_rec, self.l_ai, self.w_ai, alpha, 'a_ai', level='t', f=fe_ai)
            fe_ot, df_fe_ot = self.productivity(d_it_rec, self.l_ot, self.w_ot, gamma, 'fe_ot', level='i', f=a_ot)
            a_ot = self.productivity(d_it_rec, self.l_ot, self.w_ot, gamma, 'a_ot', level='t', f=fe_ot)
            diff = (self._sq_change(fe_ai, fe_ai_old) + self._sq_change(a_ai, a_ai_old) +
                    self._sq_change(fe_ot, fe_ot_old) + self._sq_change(a_ot, a_ot_old))
            fe_ai_old, a_ai_old, fe_ot_old, a_ot_old = fe_ai, a_ai, fe_ot, a_ot
            n_iter += 1
        return (a_ai, fe_ai, df_fe_ai), (a_ot, fe_ot, df_fe_ot), n_iter, diff

    @staticmethod
    def _normalise_fe(fe, a, df_fe, name):
        """
        Rescale so the per-employer fixed effects sum to one, moving the scale into productivity.
        The product fe * a is unchanged.
        """
        total = float(df_fe.drop_duplicates('emp')[name].sum())
        return fe / total, a * total

    def _simple_productivity(self, d_it_rec, alpha, gamma):
        """
        Productivity terms for FE in {'A', 'a', 'none'}: per month for 'A', per employer for 'a', and 1 for
        'none'.
        :return: (a_ai, a_ot)
        """
        if self.FE == 'none':
            return (self.emp * 0 + 1).rename('a_ai'), (self.emp * 0 + 1).rename('a_ot')
        level = 't' if self.FE == 'A' else 'i'
        a_ai = self.productivity(d_it_rec, self.l_ai, self.w_ai, alpha, 'a_ai', level=level)
        a_ot = self.productivity(d_it_rec, self.l_ot, self.w_ot, gamma, 'a_ot', level=level)
        return a_ai, a_ot

    def model_residuals(self, params, Debug=False):
        """
        Given labour, salary and discount-rate information, produce the residual vector to be minimised
        :param params: mapping with keys 'alpha', 'gamma', 'phi', 'delta', 'd_0_av'
        :param Debug: print diagnostics of the joint FE iteration (FE == 'both' only)
        :return: float numpy vector [data-process residuals, DM FOC, AI FOC, OT FOC]
        :raises ValueError: if FE is not one of 'both', 'A', 'a', 'none'
        """
        alpha = params['alpha']
        gamma = params['gamma']
        phi = params['phi']
        delta = params['delta']
        d_0_av = params['d_0_av']
        # 1-2. Inner part of D_it from the DM FOC and the recursive data stock
        sub_d_it = self.dit_inner(alpha, gamma, phi, delta, 'd_it')
        d_it_rec = self._recursive_data_stock(d_0_av, phi, delta)
        # 3. Data-process residuals
        data_process = self.data_process_calc(d_it_rec, sub_d_it, phi)
        # 4. Data-management FOC residuals
        foc_dm = (sub_d_it * (((self.l_dm**(1 - phi))) / d_it_rec)) - (self.w_dm * self.l_dm)
        # 5-6. Productivity / fixed effects and the AI and OT FOC residuals
        if self.FE == 'both':
            (a_ai, fe_ai, df_fe_ai), (a_ot, fe_ot, df_fe_ot), n_iter, diff = \
                self._joint_fe_productivity(d_it_rec, alpha, gamma)
            if Debug:
                print('Time and firm FE iteration converged in', n_iter, 'trials. Difference:', diff)
                foc_ai_pre = self.foc_cobb_douglas_FE(a_ai, d_it_rec, self.l_ai, self.w_ai, fe_ai, alpha)
                foc_ot_pre = self.foc_cobb_douglas_FE(a_ot, d_it_rec, self.l_ot, self.w_ot, fe_ot, gamma)
            fe_ai, a_ai = self._normalise_fe(fe_ai, a_ai, df_fe_ai, 'fe_ai')
            fe_ot, a_ot = self._normalise_fe(fe_ot, a_ot, df_fe_ot, 'fe_ot')
            foc_ai = self.foc_cobb_douglas_FE(a_ai, d_it_rec, self.l_ai, self.w_ai, fe_ai, alpha)
            foc_ot = self.foc_cobb_douglas_FE(a_ot, d_it_rec, self.l_ot, self.w_ot, fe_ot, gamma)
            if Debug:
                diff_ai = (np.abs(foc_ai_pre - foc_ai)).sum()
                diff_ot = (np.abs(foc_ot_pre - foc_ot)).sum()
                print('The sum of absolute differences in scaled and unscaled FOC is:')
                print('- AI:', float(diff_ai))
                print('- OT:', float(diff_ot))
        elif self.FE in ('A', 'a', 'none'):
            a_ai, a_ot = self._simple_productivity(d_it_rec, alpha, gamma)
            foc_ai = self.foc_cobb_douglas(a_ai, d_it_rec, self.l_ai, self.w_ai, alpha)
            foc_ot = self.foc_cobb_douglas(a_ot, d_it_rec, self.l_ot, self.w_ot, gamma)
        else:
            raise ValueError('The only supported values for FE are: both, none, A, a.')
        # 7. Stack residuals in the order [data process, DM FOC, AI FOC, OT FOC]
        residuals = np.concatenate((data_process, foc_dm, foc_ai, foc_ot))
        return residuals.astype(float)

    def run_outcome(self, alpha, gamma, phi, delta, d_0_av):
        """
        Given the estimated parameters, return the run's characteristics
        :param alpha: production function exponent on data for the AI type
        :param gamma: production function exponent on data for the old-tech type
        :param phi: production function exponent on data for the data-management type
        :param delta: depreciation rate
        :param d_0_av: target average initial data stock at firm level
        :return: DataFrame with yearmonth, emp, k_dm, k_ai, k_ot, d_it, d_it_rec, wage bills, ups and d_i0.
                 It also holds the productivity / fixed-effect columns implied by FE.
        :raises ValueError: if FE is not one of 'both', 'A', 'a', 'none'
        """
        # 1-2. Data stocks: recursive, initial, and the one implied by the DM FOC
        sub_d_it = self.dit_inner(alpha, gamma, phi, delta, 'd_it')
        d_it_rec = self._recursive_data_stock(d_0_av, phi, delta)
        ups, d_i0 = self.dit_emp(d_0_av, phi, delta, zero=True)
        d_it = (sub_d_it * ((self.l_dm**(-phi)) * (self.w_dm**(-1)))).rename('d_it')
        # 3. Data-management output term and wage bill
        k_dm = (sub_d_it * (((self.l_dm**(1 - phi))) / d_it_rec)).rename('k_dm')
        w_dm_total = (self.w_dm * self.l_dm).rename('w_dm_total')
        # 4-5. Productivity / fixed effects and the AI and OT output terms
        if self.FE == 'both':
            (a_ai, fe_ai, df_fe_ai), (a_ot, fe_ot, df_fe_ot), _, _ = \
                self._joint_fe_productivity(d_it_rec, alpha, gamma)
            fe_ai, a_ai = self._normalise_fe(fe_ai, a_ai, df_fe_ai, 'fe_ai')
            fe_ot, a_ot = self._normalise_fe(fe_ot, a_ot, df_fe_ot, 'fe_ot')
            k_ai = self.cobb_douglas_fe(a_ai, d_it_rec, self.l_ai, self.w_ai, fe_ai, alpha, 'k_ai')
            k_ot = self.cobb_douglas_fe(a_ot, d_it_rec, self.l_ot, self.w_ot, fe_ot, gamma, 'k_ot')
        elif self.FE in ('A', 'a', 'none'):
            a_ai, a_ot = self._simple_productivity(d_it_rec, alpha, gamma)
            k_ai = self.cobb_douglas(a_ai, d_it_rec, self.l_ai, self.w_ai, alpha, 'k_ai')
            k_ot = self.cobb_douglas(a_ot, d_it_rec, self.l_ot, self.w_ot, gamma, 'k_ot')
        else:
            raise ValueError('The only supported values for FE are: both, none, A, a.')
        # Cumulative salaries
        w_ai_total = (self.w_ai * self.l_ai).rename('w_ai_total')
        w_ot_total = (self.w_ot * self.l_ot).rename('w_ot_total')
        # 6. Compile the final DataFrame (merge order fixes the column order)
        parts = [k_dm, k_ai, k_ot, d_it, d_it_rec, w_dm_total, w_ai_total, w_ot_total, ups, d_i0]
        if self.FE == 'both':
            parts += [a_ai, a_ot, fe_ai, fe_ot]
        elif self.FE == 'A':
            parts += [a_ai, a_ot]
        elif self.FE == 'a':
            # Firm-level productivity is reported under the fixed-effect column names
            parts += [a_ai.rename(columns={'a_ai': 'fe_ai'}), a_ot.rename(columns={'a_ot': 'fe_ot'})]
        df = pd.merge(self.yearmonth, self.emp, left_index=True, right_index=True)
        for part in parts:
            df = df.merge(part, left_index=True, right_index=True)
        return df

    def v_0(self, df):
        """
        Per-period payoff k_ai + k_ot - w_ai_total - w_ot_total - w_dm_total, with NaNs treated as 0.
        Used as the flow term of the Bellman operator.
        :param df: outcome DataFrame produced by ``run_outcome``
        :return: Series of payoffs
        """
        return df.k_ai.fillna(0) + df.k_ot.fillna(0) - df.w_ai_total.fillna(0) - df.w_ot_total.fillna(0) - \
            df.w_dm_total.fillna(0)

    def d_t1(self, df, delta, phi, data_var):
        """
        Data stock in period t+1 given the stock in period t
        :param df: outcome DataFrame holding ``data_var`` and 'l_dm'
        :param delta: depreciation rate
        :param phi: production function exponent on data for the data-management type
        :param data_var: column holding the period-t data stock
        :return: Series (1 - delta) * D_t + l_dm^(1 - phi)
        """
        return (1 - delta) * df[data_var] + (df.l_dm ** (1 - phi))

    def bellman(self, w, df, delta, phi, data_var):
        """
        Approximate Bellman operator: T w (D) = v_0 + w(D_{t+1}) / r, interpolated on the grid ``df[data_var]``.
        :param w: callable value function (e.g. a LinInterp object) evaluated point-wise on arrays
        :param df: outcome DataFrame
        :param delta: depreciation rate
        :param phi: production function exponent on data for the data-management type
        :param data_var: column holding the data-stock grid
        :return: LinInterp representing the updated value function
        """
        next_stock = self.d_t1(df, delta, phi, data_var)
        h = self.v_0(df) + w(next_stock) / self.r
        return LinInterp(df[data_var], h)

    def value_fun_iteration(self, df, delta, phi, maxiter=1000000, tol=1e-6, printing=True):
        """
        Compute the value of the data stock by iterating the Bellman operator to a fixed point
        :param df: outcome DataFrame that holds the variables of interest for estimation
        :param delta: depreciation rate
        :param phi: production function exponent on data for the data-management type
        :param maxiter: maximum number of iterations (if <= 0, only the initial operator step is applied)
        :param tol: convergence threshold on the sup-norm change over the grid
        :param printing: whether progress is printed
        :return: (final sup-norm error, final value function); the error is inf if no iteration ran
        """
        if printing:
            print('Starting Value function iteration...')
        df = df.sort_values('d_it_rec')
        df['guess'] = np.abs(self.v_0(df) + (self.d_t1(df, delta, phi, 'd_it_rec') / self.r))
        if printing:
            print('Guess distribution:')
            print(df['guess'].describe(percentiles=[0.001, 0.002, 0.003, 0.004, 0.005, 0.006, 0.007, 0.008, 0.009, 0.01,
                                                    0.05, 0.10, 0.25, 0.5, 0.75, 0.90, 0.95, 0.99, 0.991, 0.992, 0.993,
                                                    0.994, 0.995, 0.996, 0.997, 0.998, 0.999]))
        grid = df['d_it_rec']
        V0 = LinInterp(grid, df['guess'])
        V0 = self.bellman(V0, df, delta, phi, 'd_it_rec')
        # Defaults keep the final report/return defined even when the loop body never runs
        err = np.inf
        V1 = V0
        # Cache V0 on the grid: each iteration's old values are the previous iteration's new values
        previous = np.array(V0(grid))
        count = 0
        while count < maxiter:
            V1 = self.bellman(V0, df, delta, phi, 'd_it_rec')
            current = np.array(V1(grid))
            err = np.max(np.abs(current - previous))
            V0, previous = V1, current
            count += 1
            if count % 100 == 0:
                if printing:
                    print('Iteration number: ', str(count), '. Error:', str(err))
            if err < tol:
                if printing:
                    print('The number of iteration until convergence was:', count)
                break
        if printing:
            print('The error of the final iteration is:', err)
        return err, V1
